
adapt custom allreduce for tensorrt llm #2511

Open · wants to merge 1 commit into main
Conversation

yizhang2077 (Collaborator)

Motivation

Adapt TensorRT-LLM's custom allreduce; for now we still use vllm's distributed module.
After this PR is merged and sgl-kernel is stable, we only need to replace `vllm.distributed` with `sglang.srt.distributed` and add a monkey patch; then we can remove the vllm distributed dependency.
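
A minimal sketch of the monkey-patch idea, assuming the replacement module lands at `sglang.srt.distributed`; the exact mechanism here is an illustration, not the code in this PR:

```python
# Hypothetical sketch: make any subsequent `import vllm.distributed`
# resolve to sglang's own copy of the module instead. The target path
# `sglang.srt.distributed` comes from the PR description; redirecting
# via sys.modules is an assumed mechanism for illustration.
import sys

import sglang.srt.distributed as sgl_distributed

sys.modules["vllm.distributed"] = sgl_distributed
```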

Modifications

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

zhyncs (Member) commented Dec 18, 2024

@yizhang2077 Could you paste the unit test results?

yizhang2077 (Collaborator, Author) commented Dec 18, 2024

> @yizhang2077 Could you paste the unit test results?

There are only correctness tests here. Do we need to compare with vllm?
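
For reference, a minimal sketch of the kind of correctness test meant here: compare the custom allreduce against `torch.distributed.all_reduce`. The name `custom_all_reduce` is a stand-in, not the actual API in this PR, and the default process group is assumed to be initialized already (e.g. via torchrun):

```python
# Hypothetical correctness check for a custom allreduce: each rank
# builds a distinct random tensor, and the kernel's output must match
# the NCCL reference result from torch.distributed.all_reduce.
import torch
import torch.distributed as dist


def check_correctness(custom_all_reduce, numel: int = 1 << 20) -> None:
    rank = dist.get_rank()
    torch.manual_seed(rank)  # different input data on each rank
    inp = torch.randn(numel, dtype=torch.float16, device=f"cuda:{rank}")

    expected = inp.clone()
    dist.all_reduce(expected)  # reference result (sum across ranks)

    out = custom_all_reduce(inp)  # result from the kernel under test
    torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)
```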

zhyncs (Member) commented Dec 18, 2024

^ gentle ping cc @merrymercy

zhyncs (Member) commented Dec 18, 2024

> > @yizhang2077 Could you paste the unit test results?
>
> There are only correctness tests here. Do we need to compare with vllm?

ref #2481 (comment)

zhyncs (Member) commented Dec 18, 2024

The most important case we care about is TP=8 with bs in [1, 1024]; the message size is about 0–32 MB. Can we do a more comprehensive test?
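
For context, a back-of-the-envelope for where a range like 0–32 MB can come from: each decode-step allreduce in tensor parallelism moves one hidden-state tensor of shape [batch_size, hidden_size]. The hidden size and dtype below are illustrative assumptions, not values from this PR:

```python
# Rough allreduce message size during decode: one [batch_size, hidden_size]
# tensor per reduction. hidden_size=8192 (a 70B-class model) and 2-byte
# fp16 elements are assumed values for illustration.
def allreduce_bytes(batch_size: int, hidden_size: int = 8192, dtype_bytes: int = 2) -> int:
    return batch_size * hidden_size * dtype_bytes

for bs in (1, 128, 1024):
    print(f"bs={bs}: {allreduce_bytes(bs) / 2**20:.3f} MiB")
# bs=1: 0.016 MiB, bs=128: 2.000 MiB, bs=1024: 16.000 MiB
```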

merrymercy (Contributor)

I think we can merge this one since it has a correctness test. We can benchmark the performance in future PRs.

Conditions for switching to this:

  • Faster than or equal to vllm's custom allreduce in all cases: (TP = 2, 4, 8) × (bs = 1, 2, 4, 8, …, 128, 1024)
  • Does not break AMD support

yizhang2077 mentioned this pull request on Dec 31, 2024
yizhang2077 (Collaborator, Author) commented Jan 1, 2025

Benchmark command:

```shell
python -m sglang.bench_one_batch --model-path Meta-Llama-3.1-8B-Instruct --batch 1 2 4 8 16 32 64 128 256 512 --input-len 128 --output-len 1024 --run-name test_run --tp x
```

TL;DR: it is still slightly slower than vllm in most cases, possibly because vllm applies an optimization specific to CUDA graph mode that saves the cost of a cudaMemcpy.

vllm result

| tp size | batch size | median decode latency (s) | median decode throughput (token/s) |
| --- | --- | --- | --- |
| 2/4/8 | 1 | 0.00771 / 0.00494 / 0.00373 | 129.73 / 202.57 / 268.30 |
| 2/4/8 | 2 | 0.00789 / 0.00517 / 0.00411 | 253.59 / 386.75 / 486.38 |
| 2/4/8 | 4 | 0.00803 / 0.00525 / 0.00428 | 498.40 / 761.53 / 933.78 |
| 2/4/8 | 8 | 0.00840 / 0.00541 / 0.00485 | 952.47 / 1479.86 / 1648.54 |
| 2/4/8 | 16 | 0.00884 / 0.00586 / 0.00560 | 1810.33 / 2726.23 / 2855.09 |
| 2/4/8 | 32 | 0.00943 / 0.00679 / 0.00545 | 3394.14 / 4710.55 / 5871.54 |
| 2/4/8 | 64 | 0.01082 / 0.00749 / 0.00594 | 5912.68 / 8541.01 / 10782.27 |
| 2/4/8 | 128 | 0.01485 / 0.01006 / 0.007821 | 8618.89 / 12725.68 / 16366.52 |
| 2/4/8 | 256 | 0.02451 / 0.01699 / 0.01717 | 10446.48 / 15066.25 / 14909.56 |
| 2/4/8 | 512 | 0.04029 / 0.02787 / 0.01925 | 12707.00 / 18372.62 / 26596.86 |

custom allreduce result

| tp size | batch size | median decode latency (s) | median decode throughput (token/s) |
| --- | --- | --- | --- |
| 2/4/8 | 1 | 0.00774 / 0.00508 / 0.00401 | 129.17 / 196.75 / 248.93 |
| 2/4/8 | 2 | 0.00793 / 0.00529 / 0.00437 | 252.20 / 378.00 / 457.52 |
| 2/4/8 | 4 | 0.00805 / 0.00538 / 0.00456 | 496.88 / 744.01 / 877.56 |
| 2/4/8 | 8 | 0.00840 / 0.00557 / 0.00504 | 951.98 / 1434.99 / 1586.42 |
| 2/4/8 | 16 | 0.00882 / 0.00599 / 0.00571 | 1812.92 / 2669.09 / 2801.30 |
| 2/4/8 | 32 | 0.00948 / 0.00692 / 0.00571 | 3374.59 / 4627.39 / 5596.83 |
| 2/4/8 | 64 | 0.01101 / 0.00768 / 0.00603 | 5809.91 / 8323.83 / 10597.94 |
| 2/4/8 | 128 | 0.01524 / 0.00997 / 0.00808 | 8395.69 / 12833.05 / 15831.76 |
| 2/4/8 | 256 | 0.02471 / 0.01617 / 0.01685 | 10356.20 / 15829.66 / 15190.94 |
| 2/4/8 | 512 | 0.04038 / 0.02786 / 0.01926 | 12679.31 / 18377.18 / 26573.49 |
